from Utils import *

def aggregated_filter(df, path):
    """Flag rows by the row-wise sum of log2 scores across all ``LMs_columns``.

    Adds two columns to ``df`` in place: ``'sum'`` (sum over ``LMs_columns``
    of ``log2`` of each value) and ``'keep'`` (1 when ``sum <= path[2]``,
    otherwise 0). Prints summary statistics and the 20 most extreme rows on
    each side of the threshold, saves two histograms, and writes the full
    flagged frame to ``path[0] + '-aggr-filtered.csv'``.

    Args:
        df: DataFrame containing every column in ``LMs_columns`` and the
            column named by ``path[1]``.
        path: Sequence ``[output prefix, text column name, threshold]``.
    """
    print('\n\n\n\n <----------------------> AGGREGATED <----------------------> ')

    # Aggregate score: sum of log2 over the language-model columns of each row.
    lm_values = df[LMs_columns].to_numpy()
    df['sum'] = np.log2(lm_values).sum(axis=1)
    print(df['sum'].describe())
    plot_histogram(df, 'sum', path[0] + '-1.png')

    # 1 = at or below the threshold (kept), 0 = above it (filtered out).
    df['keep'] = np.where(df['sum'] <= path[2], 1, 0)
    print(df['keep'].value_counts())
    kept_mask = df['keep'] == 1

    kept_rows = df[kept_mask]
    plot_histogram(kept_rows, 'sum', path[0] + '-2.png')
    print('- KEPT - ')
    print(kept_rows[['sum', path[1]]].sort_values(by='sum', ascending=True).head(20))

    dropped_rows = df[~kept_mask]
    print('- OUT - ')
    print(dropped_rows[['sum', path[1]]].sort_values(by='sum', ascending=False).head(20))

    # The CSV holds every row; consumers select on the 'keep' flag.
    df.to_csv(path[0] + '-aggr-filtered.csv', index=False)

def disaggregated_filter(df, path):
    """Filter rows separately for each column in ``LMs_columns``.

    For every column the values are log2-transformed and a row is kept when
    its transformed value is at or below that column's mean. Summary
    statistics, keep/drop counts and the 20 most extreme rows on each side
    are printed, and two histograms per column are saved under the
    ``path[0]`` prefix.

    The transform is applied to a per-column copy, so ``df`` itself is left
    untouched and the function can safely be called more than once on the
    same frame.

    Args:
        df: DataFrame containing every column in ``LMs_columns`` and the
            column named by ``path[1]``.
        path: Sequence whose first entry is the output file prefix and whose
            second entry is the name of the text column to report.

    Returns:
        dict mapping each column name in ``LMs_columns`` to the Series of
        ``path[1]`` values that were kept for that column.
    """
    print('\n\n\n\n <----------------------> DISAGGREGATED <----------------------> ')
    keep = {}
    for LM in LMs_columns:
        print('\n\n\n\n <----------------------> ' + LM)
        df_temp = df[[LM, path[1]]].copy()
        # Vectorised log2 on the copy replaces the deprecated element-wise
        # DataFrame.applymap and avoids overwriting the caller's data.
        df_temp[LM] = np.log2(df_temp[LM])
        print(df_temp[LM].describe())
        plot_histogram(df_temp, LM, path[0] + '-' + LM + '-' + '1.png')

        # Per-column cut-off: keep rows at or below the mean log2 value.
        threshold = df_temp[LM].mean()
        df_temp['keep'] = np.where(df_temp[LM] <= threshold, 1, 0)
        print(df_temp['keep'].value_counts())

        df_filtered = df_temp[df_temp['keep'] == 1]
        keep[LM] = df_filtered[path[1]]
        plot_histogram(df_filtered, LM, path[0] + '-' + LM + '-' + '2.png')
        print('- KEPT - ')
        print(df_filtered.sort_values(by=LM, ascending=True)[[LM, path[1]]][:20])

        df_filtered = df_temp[df_temp['keep'] == 0]
        print('- OUT - ')
        print(df_filtered.sort_values(by=LM, ascending=False)[[LM, path[1]]][:20])
    return keep

def plot_histogram(df, column, path):
    """Save a histogram of ``df[column]`` to ``path`` with percentage labels.

    Every non-empty bar is annotated with its count expressed as a
    percentage of ``len(df)``. The figure is closed after saving so that
    repeated calls do not accumulate open matplotlib figures.

    Args:
        df: DataFrame holding the data to plot.
        column: Name of the column to draw.
        path: Destination file path for the saved image.
    """
    axes = df.hist(column=column, grid=False, figsize=(12, 8), color='#86bf91', rwidth=0.9)
    ax = axes[0, 0]
    total = len(df)
    for patch in ax.patches:
        height = patch.get_height()
        if height > 0:
            # The label is vertically centred on the bar top; the trailing
            # newline appears intended to push the visible text above the bar.
            ax.text(
                patch.get_x() + patch.get_width() / 2,
                height,
                f'{height / total * 100.0:.0f} %\n',
                ha='center',
                va='center',
            )
    ax.grid(True, axis='y', ls=':', alpha=0.4)
    ax.set_axisbelow(True)
    for side in ['left', 'right', 'top']:
        ax.spines[side].set_visible(False)
    ax.tick_params(axis="y", length=0)
    ax.margins(x=0.02)

    # Save through the axes' own figure and release it afterwards.
    fig = ax.figure
    fig.savefig(path)
    plt.close(fig)

# Run configuration, read positionally by both filters:
#   [0] path prefix shared by the input CSV and every generated output file
#   [1] name of the column printed alongside the scores
#   [2] upper bound on the summed log2 score used by aggregated_filter
path = [
    '../Social Bias Probing/Stereotypes-w-PPLs',
    'stereotype',
    130,
]

# The same DataFrame object is handed to both filters, in this order.
df = pd.read_csv(f'{path[0]}.csv')
aggregated_filter(df, path)
disaggregated_filter(df, path)